/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;

class OpNativeBatchNormLegitNoTrainingOutTest : public OperatorTest {
 protected:
  ::std::tuple<
      executorch::aten::Tensor&,
      executorch::aten::Tensor&,
      executorch::aten::Tensor&>
  op_native_batch_norm_legit_no_training_out(
      const executorch::aten::Tensor& input,
      const std::optional<executorch::aten::Tensor>& weight,
      const std::optional<executorch::aten::Tensor>& bias,
      const executorch::aten::Tensor& running_mean,
      const executorch::aten::Tensor& running_var,
      double momentum,
      double eps,
      executorch::aten::Tensor& out0,
      executorch::aten::Tensor& out1,
      executorch::aten::Tensor& out2) {
    return torch::executor::aten::_native_batch_norm_legit_no_training_outf(
        context_,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        momentum,
        eps,
        out0,
        out1,
        out2);
  }

  template <executorch::aten::ScalarType DTYPE>
  void test_2d_dtype() {
    torch::executor::testing::TensorFactory<DTYPE> tf;

    executorch::aten::Tensor input = tf.make(
        {4, 7}, {2.876736640930176,  7.67944860458374,   5.701690196990967,
                 9.299789428710938,  3.023690700531006,  5.315116882324219,
                 7.185585021972656,  6.911304473876953,  7.61051082611084,
                 1.4963287115097046, 0.7381612062454224, 8.588483810424805,
                 6.583977699279785,  8.831110000610352,  0.8165055513381958,
                 7.087201118469238,  5.572513580322266,  4.446897983551025,
                 4.444573402404785,  6.254056930541992,  5.906398296356201,
                 9.971039772033691,  3.5423521995544434, 7.452159881591797,
                 9.93700122833252,   1.8560808897018433, 1.524025797843933,
                 7.3222975730896});
    std::optional<executorch::aten::Tensor> weight =
        std::optional<executorch::aten::Tensor>(tf.make(
            {7},
            {8.287437438964844,
             8.227645874023438,
             6.65926456451416,
             9.436124801635742,
             4.119281768798828,
             8.593960762023926,
             2.3760855197906494}));
    std::optional<executorch::aten::Tensor> bias =
        std::optional<executorch::aten::Tensor>(tf.make(
            {7},
            {7.824275970458984,
             6.84327507019043,
             8.354326248168945,
             8.773970603942871,
             3.89609694480896,
             3.0753469467163086,
             3.1105971336364746}));
    executorch::aten::Tensor running_mean = tf.make(
        {7},
        {9.700226783752441,
         0.1234668493270874,
         7.527220249176025,
         8.993252754211426,
         0.4736626148223877,
         7.7135701179504395,
         5.12320613861084});
    executorch::aten::Tensor running_var = tf.make(
        {7},
        {3.585531234741211,
         6.615292549133301,
         0.24084866046905518,
         5.175800323486328,
         0.5886000394821167,
         6.23909854888916,
         1.5029621124267578});
    double momentum = 0.1;
    double eps = 0;
    executorch::aten::Tensor out0 = tf.zeros({4, 7});
    executorch::aten::Tensor out1 = tf.zeros({0});
    executorch::aten::Tensor out2 = tf.zeros({0});
    executorch::aten::Tensor out0_expected = tf.make(
        {4, 7}, {-22.039867401123047, 31.014127731323242,  -16.416650772094727,
                 10.04538631439209,   17.5877628326416,    -5.17673921585083,
                 7.1078033447265625,  -4.381907939910889,  30.793603897094727,
                 -73.48003387451172,  -25.46548080444336,  47.46636962890625,
                 -0.8111140131950378, 10.29708194732666,   -31.056814193725586,
                 29.119586944580078,  -18.16947364807129,  -10.082839965820312,
                 25.216796875,        -1.9462348222732544, 4.628543376922607,
                 9.00953483581543,    17.779958724975586,  7.335818767547607,
                 12.688335418701172,  11.318607330322266,  -18.22031593322754,
                 7.372773170471191});
    executorch::aten::Tensor out1_expected = tf.make({0}, {});
    executorch::aten::Tensor out2_expected = tf.make({0}, {});
    op_native_batch_norm_legit_no_training_out(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        momentum,
        eps,
        out0,
        out1,
        out2);
    if (DTYPE == executorch::aten::ScalarType::Half ||
        DTYPE == executorch::aten::ScalarType::BFloat16) {
      EXPECT_TENSOR_CLOSE_WITH_TOL(
          out0,
          out0_expected,
          4e-2,
          executorch::runtime::testing::internal::kDefaultAtol);
      EXPECT_TENSOR_CLOSE_WITH_TOL(
          out1,
          out1_expected,
          2e-2,
          executorch::runtime::testing::internal::kDefaultAtol);
      EXPECT_TENSOR_CLOSE_WITH_TOL(
          out2,
          out2_expected,
          2e-2,
          executorch::runtime::testing::internal::kDefaultAtol);
    } else {
      EXPECT_TENSOR_CLOSE(out0, out0_expected);
      EXPECT_TENSOR_CLOSE(out1, out1_expected);
      EXPECT_TENSOR_CLOSE(out2, out2_expected);
    }
  }
};

class OpNativeBatchNormLegitOutTest : public OperatorTest {
 protected:
  ::std::tuple<
      executorch::aten::Tensor&,
      executorch::aten::Tensor&,
      executorch::aten::Tensor&>
  op_native_batch_norm_legit_out(
      const executorch::aten::Tensor& input,
      const std::optional<executorch::aten::Tensor>& weight,
      const std::optional<executorch::aten::Tensor>& bias,
      executorch::aten::Tensor& running_mean,
      executorch::aten::Tensor& running_var,
      bool training,
      double momentum,
      double eps,
      executorch::aten::Tensor& out0,
      executorch::aten::Tensor& out1,
      executorch::aten::Tensor& out2) {
    executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext context{};
    return torch::executor::aten::_native_batch_norm_legit_outf(
        context,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
        out0,
        out1,
        out2);
  }
};

class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
 protected:
  ::std::tuple<
      executorch::aten::Tensor&,
      executorch::aten::Tensor&,
      executorch::aten::Tensor&>
  op_native_batch_norm_legit_no_stats_out(
      const executorch::aten::Tensor& input,
      const std::optional<executorch::aten::Tensor>& weight,
      const std::optional<executorch::aten::Tensor>& bias,
      bool training,
      double momentum,
      double eps,
      executorch::aten::Tensor& out0,
      executorch::aten::Tensor& out1,
      executorch::aten::Tensor& out2) {
    return torch::executor::aten::_native_batch_norm_legit_outf(
        context_,
        input,
        weight,
        bias,
        training,
        momentum,
        eps,
        out0,
        out1,
        out2);
  }

  template <executorch::aten::ScalarType DTYPE>
  void test_2d_dtype() {
    torch::executor::testing::TensorFactory<DTYPE> tf;

    executorch::aten::Tensor input =
        tf.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
    std::optional<executorch::aten::Tensor> weight =
        std::optional<executorch::aten::Tensor>();
    std::optional<executorch::aten::Tensor> bias =
        std::optional<executorch::aten::Tensor>();
    bool training = true;
    double momentum = 1e-3;
    double eps = 1e-5;
    executorch::aten::Tensor out0 = tf.zeros({3, 4});
    executorch::aten::Tensor out1 = tf.zeros({4});
    executorch::aten::Tensor out2 = tf.zeros({4});
    executorch::aten::Tensor out0_expected = tf.make(
        {3, 4},
        {-0.98058063,
         -1.03422451,
         -1.06904495,
         -1.09332705,
         -0.39223224,
         -0.31822300,
         -0.26726127,
         -0.23017406,
         1.37281299,
         1.35244739,
         1.33630610,
         1.32350123});
    executorch::aten::Tensor out1_expected =
        tf.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
    executorch::aten::Tensor out2_expected =
        tf.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
    op_native_batch_norm_legit_no_stats_out(
        input, weight, bias, training, momentum, eps, out0, out1, out2);
    if (DTYPE == executorch::aten::ScalarType::Half ||
        DTYPE == executorch::aten::ScalarType::BFloat16) {
      EXPECT_TENSOR_CLOSE_WITH_TOL(
          out0,
          out0_expected,
          2e-2,
          executorch::runtime::testing::internal::kDefaultAtol);
      EXPECT_TENSOR_CLOSE_WITH_TOL(
          out1,
          out1_expected,
          1e-2,
          executorch::runtime::testing::internal::kDefaultAtol);
      EXPECT_TENSOR_CLOSE_WITH_TOL(
          out2,
          out2_expected,
          2e-2,
          executorch::runtime::testing::internal::kDefaultAtol);
    } else {
      EXPECT_TENSOR_CLOSE(out0, out0_expected);
      EXPECT_TENSOR_CLOSE(out1, out1_expected);
      EXPECT_TENSOR_CLOSE(out2, out2_expected);
    }
  }
};

TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
#define TEST_ENTRY(ctype, dtype) \
  test_2d_dtype<executorch::aten::ScalarType::dtype>();
  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
#undef TEST_ENTRY
}

TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest3D) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input = tfFloat.make(
      {4, 7, 5}, {5.277339935302734,  5.94276237487793,     6.543086051940918,
                  2.411264181137085,  8.980886459350586,    2.7123653888702393,
                  9.466896057128906,  9.324702262878418,    1.9848430156707764,
                  8.388091087341309,  1.5069717168807983,   5.350819110870361,
                  1.727534532546997,  4.913003444671631,    2.555372714996338,
                  4.321412563323975,  1.107364296913147,    6.048641681671143,
                  9.496582984924316,  0.9668296575546265,   0.8103430271148682,
                  8.187652587890625,  9.455179214477539,    0.5739009380340576,
                  3.550161838531494,  1.5362483263015747,   7.338945388793945,
                  3.583885431289673,  6.5086517333984375,   0.9027481079101562,
                  0.8805221319198608, 3.983092784881592,    5.43976354598999,
                  9.080245971679688,  2.602390766143799,    2.1537625789642334,
                  3.2551045417785645, 7.098634719848633,    8.135055541992188,
                  7.457048416137695,  5.3438568115234375,   3.7822632789611816,
                  3.4284191131591797, 6.144853115081787,    9.79615592956543,
                  5.735219955444336,  2.5468051433563232,   8.514262199401855,
                  3.775507926940918,  8.327726364135742,    4.772212505340576,
                  7.100861072540283,  3.477612018585205,    9.359293937683105,
                  5.203947067260742,  3.6150975227355957,   6.159048557281494,
                  0.9919929504394531, 1.6809028387069702,   0.3627735376358032,
                  1.8791186809539795, 4.037001132965088,    8.129783630371094,
                  4.79802131652832,   2.9911656379699707,   8.659820556640625,
                  7.378345489501953,  3.6833512783050537,   2.4555420875549316,
                  8.481515884399414,  3.733121156692505,    6.075705528259277,
                  6.900073051452637,  6.380939960479736,    3.204977512359619,
                  2.058135986328125,  4.60728120803833,     7.737727165222168,
                  5.3178815841674805, 9.224492073059082,    4.838874340057373,
                  2.717348337173462,  1.8555694818496704,   1.856197714805603,
                  7.189084053039551,  5.280246257781982,    7.550882816314697,
                  0.6145977973937988, 6.764681816101074,    4.217874526977539,
                  0.89302659034729,   2.4634499549865723,   3.51415753364563,
                  5.038887977600098,  4.948186874389648,    8.326996803283691,
                  8.919670104980469,  4.45585298538208,     0.5209791660308838,
                  4.2513017654418945, 0.047875046730041504, 2.453791618347168,
                  6.113187789916992,  5.47722053527832,     7.524778842926025,
                  0.3724473714828491, 2.6570069789886475,   9.420238494873047,
                  4.650344371795654,  4.206380844116211,    1.2107867002487183,
                  3.3689606189727783, 4.082674980163574,    5.31553840637207,
                  4.759864807128906,  5.461820602416992,    2.0690488815307617,
                  9.234517097473145,  1.6740238666534424,   3.492245674133301,
                  9.844581604003906,  4.278226852416992,    2.9611783027648926,
                  9.626322746276855,  7.756594657897949,    3.4873299598693848,
                  6.345180988311768,  5.55388069152832,     8.535417556762695,
                  8.509242057800293,  8.684778213500977,    3.784114122390747,
                  3.887125253677368,  9.278786659240723,    6.742891311645508,
                  5.01821756362915,   2.326876640319824,    7.939553737640381,
                  3.2622408866882324, 3.829448699951172});
  std::optional<executorch::aten::Tensor> weight =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {7},
          {0.5193436145782471,
           4.531304836273193,
           8.960723876953125,
           8.598731994628906,
           2.6848177909851074,
           7.309220314025879,
           2.2476916313171387}));
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {7},
          {4.643010139465332,
           0.2791440486907959,
           3.6721653938293457,
           3.918765068054199,
           2.6499342918395996,
           5.721188545227051,
           5.901060104370117}));
  executorch::aten::Tensor running_mean = tfFloat.make(
      {7},
      {5.818909645080566,
       5.325511932373047,
       7.094021797180176,
       4.9185566902160645,
       5.608961582183838,
       3.7719011306762695,
       6.7734270095825195});
  executorch::aten::Tensor running_var = tfFloat.make(
      {7},
      {8.8593168258667,
       3.440363883972168,
       7.105681896209717,
       1.0423260927200317,
       6.756608009338379,
       4.527579307556152,
       2.022289752960205});
  double momentum = 0.1;
  double eps = 0;
  executorch::aten::Tensor out0 = tfFloat.zeros({4, 7, 5});
  executorch::aten::Tensor out1 = tfFloat.zeros({0});
  executorch::aten::Tensor out2 = tfFloat.zeros({0});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {4, 7, 5}, {4.5485148429870605,  4.664620399475098,   4.76936674118042,
                  4.048431873321533,   5.194723129272461,   -6.104737281799316,
                  10.396490097045898,  10.049112319946289,  -7.8820648193359375,
                  7.760983943939209,   -15.109009742736816, -2.1877059936523438,
                  -14.367575645446777, -3.659447431564331,  -11.584752082824707,
                  -1.1105821132659912, -28.180377960205078, 13.436722755432129,
                  42.476444244384766,  -29.3640079498291,   -2.306469440460205,
                  5.313416481018066,   6.622621059417725,   -2.5506861209869385,
                  0.5234383940696716,  -1.9584782123565674, 17.97430419921875,
                  5.075337886810303,   15.122170448303223,  -4.134607791900635,
                  -3.413116931915283,  1.4907281398773193,  3.793105363845825,
                  9.547160148620605,   -0.69157475233078,   4.003501892089844,
                  4.1956682205200195,  4.8663010597229,     5.047139644622803,
                  4.92883825302124,    0.3239605128765106,  -3.4909913539886475,
                  -4.3554277420043945, 2.2807836532592773,  11.20086669921875,
                  -0.8955214619636536, -11.613553047180176, 8.446381568908691,
                  -7.483201026916504,  7.819331169128418,   2.686206579208374,
                  22.29886817932129,   -8.217354774475098,  41.320152282714844,
                  6.322420597076416,   0.5905092358589172,  3.218108892440796,
                  -2.1188466548919678, -1.4072843790054321, -2.7687556743621826,
                  -0.7806879878044128, 6.63183069229126,    20.690902709960938,
                  9.246002197265625,   3.039292335510254,   8.882646560668945,
                  6.857179164886475,   1.016964316368103,   -0.9236800670623779,
                  8.600822448730469,   4.279074192047119,   4.687816619873047,
                  4.831655502319336,   4.741075038909912,   4.1869215965271,
                  -7.7030110359191895, -1.475483775138855,  6.172153472900391,
                  0.2605033814907074,  9.804300308227539,   -3.9086363315582275,
                  -11.040262222290039, -13.937179565429688, -13.935067176818848,
                  3.991722345352173,   6.965037822723389,   26.08910369873047,
                  -32.330623626708984, 19.467453002929688,  -1.9826143980026245,
                  -2.221067190170288,  -0.5990060567855835, 0.48625022172927856,
                  2.0611159801483154,  1.9674323797225952,  21.36834716796875,
                  23.404233932495117,  8.070624351501465,   -5.446018218994141,
                  7.367972373962402,   -4.729177951812744,  -0.9264468550682068,
                  4.8575029373168945,  3.852308988571167,   7.08862829208374,
                  3.6926915645599365,  4.091310024261475,   5.271382808685303,
                  4.439114570617676,   4.361649990081787,   -9.77307415008545,
                  -4.5006842613220215, -2.757089614868164,  0.25477901101112366,
                  -1.1027240753173828, -1.8145684003829956, -13.21955680847168,
                  10.867557525634766,  -14.547454833984375, -8.435402870178223,
                  45.407405853271484,  -1.4743067026138306, -12.566932678222656,
                  43.569156646728516,  27.821678161621094,  0.45854052901268005,
                  3.4103615283966064,  2.5930423736572266,  5.672616004943848,
                  5.645579814910889,   22.59735870361328,   5.76314115524292,
                  6.116993427276611,   24.63783073425293,   15.926804542541504,
                  3.1268203258514404,  -1.1270453929901123, 7.744210720062256,
                  0.3513677716255188,  1.2478822469711304});
  executorch::aten::Tensor out1_expected = tfFloat.make({0}, {});
  executorch::aten::Tensor out2_expected = tfFloat.make({0}, {});
  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      momentum,
      eps,
      out0,
      out1,
      out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest4D) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input = tfFloat.make(
      {2, 4, 5, 5},
      {8.0573148727417,     2.2901253700256348,  2.783101797103882,
       2.095468044281006,   6.389344215393066,   6.702191352844238,
       6.535638809204102,   1.8584740161895752,  6.037202835083008,
       7.588045120239258,   0.7384824752807617,  7.876931190490723,
       2.198972225189209,   0.8259981870651245,  8.311962127685547,
       8.748727798461914,   6.331905841827393,   1.6120970249176025,
       7.9793596267700195,  9.730956077575684,   8.96406078338623,
       7.34755802154541,    1.0760420560836792,  6.761768341064453,
       3.18643856048584,    0.32129645347595215, 5.146165370941162,
       1.9008630514144897,  3.1616015434265137,  3.077312707901001,
       7.684902667999268,   3.1405091285705566,  6.699800491333008,
       0.7976526021957397,  1.5945738554000854,  0.7354140281677246,
       9.370306015014648,   4.1550726890563965,  4.169681549072266,
       5.389268398284912,   6.883472442626953,   8.881608963012695,
       7.600193500518799,   8.894989967346191,   1.7032986879348755,
       8.945396423339844,   1.6370415687561035,  7.708703994750977,
       7.488667964935303,   7.315606594085693,   5.349757194519043,
       6.913224220275879,   3.6051642894744873,  3.8086843490600586,
       3.2311654090881348,  4.91132926940918,    1.331128478050232,
       2.73335337638855,    0.46345293521881104, 8.168035507202148,
       8.112630844116211,   9.38737678527832,    8.532957077026367,
       8.641634941101074,   7.772867679595947,   3.7504279613494873,
       1.1857783794403076,  7.61868953704834,    9.75157642364502,
       3.6754441261291504,  2.468808174133301,   6.380059719085693,
       6.197269439697266,   7.659857273101807,   6.72884464263916,
       9.320260047912598,   1.9144713878631592,  6.228992462158203,
       2.7658307552337646,  6.0448317527771,     1.1033517122268677,
       7.482324600219727,   4.140635013580322,   0.4461771249771118,
       9.729606628417969,   7.259793758392334,   7.154001235961914,
       8.320201873779297,   0.8773839473724365,  6.855964660644531,
       4.737044334411621,   4.0600152015686035,  6.474225044250488,
       0.8523398637771606,  3.7826621532440186,  5.399431228637695,
       0.17764925956726074, 5.480880260467529,   1.5790224075317383,
       7.965246200561523,   0.919603705406189,   6.623161315917969,
       6.618031978607178,   1.6051316261291504,  0.07815778255462646,
       7.8453497886657715,  2.781987190246582,   0.28109610080718994,
       9.149931907653809,   7.448637962341309,   5.52522087097168,
       4.095173358917236,   6.3080902099609375,  5.314402103424072,
       8.845094680786133,   6.3725972175598145,  1.9547373056411743,
       5.2839508056640625,  3.5294246673583984,  3.570653200149536,
       2.5026822090148926,  0.5656778812408447,  8.309356689453125,
       0.7813519239425659,  2.366170883178711,   9.322799682617188,
       0.5455368757247925,  0.7133877277374268,  6.577077388763428,
       8.393207550048828,   5.753355979919434,   7.874646186828613,
       6.351865768432617,   7.233908176422119,   7.866637706756592,
       5.024176120758057,   5.872377872467041,   0.3430730104446411,
       1.7413997650146484,  7.130331993103027,   7.7794294357299805,
       8.817843437194824,   4.551261901855469,   4.685880661010742,
       0.4518568515777588,  3.2571589946746826,  9.467324256896973,
       6.947274208068848,   1.1890357732772827,  4.438136100769043,
       0.790744423866272,   0.9745275974273682,  2.3840129375457764,
       9.280584335327148,   7.309266090393066,   6.359057903289795,
       4.779758930206299,   6.523046970367432,   2.581796169281006,
       7.4173126220703125,  5.556275844573975,   6.3515143394470215,
       9.909261703491211,   4.264077663421631,   1.5390598773956299,
       6.409996032714844,   9.431000709533691,   6.966275215148926,
       6.593939781188965,   9.72049331665039,    8.224472045898438,
       1.1502748727798462,  9.417522430419922,   2.0071351528167725,
       7.99619722366333,    5.217411518096924,   0.5482637882232666,
       3.6407017707824707,  9.56554889678955,    5.932462215423584,
       8.26833724975586,    2.5603179931640625,  7.974213600158691,
       6.683809280395508,   5.0010175704956055,  8.93687915802002,
       4.7291178703308105,  1.1585253477096558,  2.50417423248291,
       3.685148239135742,   0.36632418632507324, 7.834067344665527,
       9.173870086669922,   3.781676769256592,   5.6734232902526855,
       3.301741600036621,   1.3799077272415161,  8.990988731384277,
       2.2520315647125244,  2.483280897140503});
  std::optional<executorch::aten::Tensor> weight =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {4},
          {1.8311285972595215,
           5.851841926574707,
           6.108979225158691,
           5.1755266189575195}));
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {4},
          {5.1375732421875,
           3.7950849533081055,
           2.406358242034912,
           5.785604476928711}));
  executorch::aten::Tensor running_mean = tfFloat.make(
      {4},
      {2.8203158378601074,
       3.1786017417907715,
       1.9189423322677612,
       1.8829244375228882});
  executorch::aten::Tensor running_var = tfFloat.make(
      {4},
      {1.4411485195159912,
       7.426868438720703,
       7.584629535675049,
       5.526189804077148});
  double momentum = 0.1;
  double eps = 0;
  executorch::aten::Tensor out0 = tfFloat.zeros({2, 4, 5, 5});
  executorch::aten::Tensor out1 = tfFloat.zeros({0});
  executorch::aten::Tensor out2 = tfFloat.zeros({0});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {2, 4, 5, 5},
      {13.125737190246582,   4.328856468200684,    5.080809593200684,
       4.031939506530762,    10.581527709960938,   11.058723449707031,
       10.804675102233887,   3.6704447269439697,   10.044395446777344,
       12.409944534301758,   1.962085485458374,    12.850592613220215,
       4.189817905426025,    2.095576047897339,    13.514159202575684,
       14.180371284484863,   10.493914604187012,   3.29463791847229,
       13.006829261779785,   15.678596496582031,   14.50882625579834,
       12.043122291564941,   2.476976156234741,    11.149598121643066,
       5.6960320472717285,   -2.340364456176758,   8.020005226135254,
       1.0514155626296997,   3.7585806846618652,   3.5775885581970215,
       13.47139835357666,    3.713289260864258,    11.35610294342041,
       -1.3174920082092285,  0.3937252461910248,   -1.4511359930038452,
       17.09044075012207,    5.891846656799316,    5.923215866088867,
       8.542016983032227,    11.75049877166748,    16.041067123413086,
       13.28950309753418,    16.069801330566406,   0.6271884441375732,
       16.178037643432617,   0.48491552472114563,  13.522506713867188,
       13.050026893615723,   12.678414344787598,   10.01660442352295,
       13.48469352722168,    6.146742343902588,    6.598191261291504,
       5.317136287689209,    9.044082641601562,    1.1024672985076904,
       4.212887763977051,    -0.8222138285636902,  16.26811981201172,
       16.145221710205078,   18.972867965698242,   17.077590942382812,
       17.318660736083984,   15.391557693481445,   6.468966484069824,
       0.7800511717796326,   15.049558639526367,   19.780736923217773,
       6.302637100219727,    3.626072645187378,    12.30202579498291,
       11.896559715270996,   15.140877723693848,   13.075701713562012,
       22.15976333618164,    5.855058670043945,    15.353979110717773,
       7.729425430297852,    14.948527336120605,   4.069284439086914,
       18.11333465576172,    10.756217002868652,   2.6224381923675537,
       23.06098747253418,    17.6234073638916,     17.390493392944336,
       19.958019256591797,   3.5717902183532715,   16.734331130981445,
       12.069281578063965,   10.578722953796387,   15.89388656616211,
       3.516652822494507,    9.968097686767578,    13.527603149414062,
       2.031242847442627,    13.70692253112793,    5.1165289878845215,
       19.176542282104492,   2.2383556365966797,   10.938176155090332,
       10.930352210998535,   3.284013509750366,    0.9548709392547607,
       12.802419662475586,   5.079109191894531,    1.2644193172454834,
       14.792341232299805,   12.19730281829834,    9.263452529907227,
       7.082154750823975,    10.457588195800781,   8.94188404083252,
       14.327363014221191,   10.55598258972168,    3.8172783851623535,
       8.895435333251953,    6.2191996574401855,   6.282087326049805,
       4.653076171875,       1.6985011100769043,   13.510184288024902,
       2.027475595474243,    4.444851398468018,    16.98843002319336,
       -1.8588563203811646,  -1.4984327554702759,  11.092581748962402,
       14.992330551147461,   9.323816299438477,    13.87883186340332,
       10.608987808227539,   12.502984046936035,   13.861635208129883,
       7.758059501647949,    9.579390525817871,    -2.2936041355133057,
       0.7090023756027222,   12.280576705932617,   13.6743745803833,
       15.904145240783691,   6.742578029632568,    7.031642436981201,
       -2.060014009475708,   3.9637696743011475,   17.298765182495117,
       11.887499809265137,   -0.47708070278167725, 6.499664306640625,
       -0.09621463716030121, 0.3114539086818695,   3.4379796981811523,
       18.735980987548828,   14.363194465637207,   12.255439758300781,
       8.752232551574707,    12.619200706481934,   3.8767030239105225,
       14.602864265441895,   10.47470474243164,    12.238706588745117,
       20.13051414489746,    7.608346462249756,    1.5637015104293823,
       12.368430137634277,   19.06963539123535,    13.602371215820312,
       12.77645492553711,    19.711788177490234,   16.393308639526367,
       0.7012971639633179,   19.039737701416016,   2.601987838745117,
       15.886947631835938,   13.12686538696289,    2.847193956375122,
       9.655555725097656,    22.69979476928711,    14.701132774353027,
       19.843833923339844,   7.276965141296387,    19.196285247802734,
       16.355310440063477,   12.6504487991333,     21.315706253051758,
       12.051830291748047,   4.190755844116211,    7.153357982635498,
       9.753409385681152,    2.4466326236724854,   18.887737274169922,
       21.83746910095215,    9.96592903137207,     14.130828857421875,
       8.909295082092285,    4.678154945373535,    21.43483543395996,
       6.598236560821533,    7.107358932495117});
  executorch::aten::Tensor out1_expected = tfFloat.make({0}, {});
  executorch::aten::Tensor out2_expected = tfFloat.make({0}, {});
  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      momentum,
      eps,
      out0,
      out1,
      out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTestDouble) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Double>
      tfDouble;

  executorch::aten::Tensor input = tfDouble.make(
      {3, 4, 3, 3},
      {0.09871780872344971, 5.7593607902526855,  4.542290687561035,
       9.888419151306152,   4.6276702880859375,  0.23040294647216797,
       5.160412311553955,   5.192661285400391,   7.774633407592773,
       3.82037353515625,    6.421841621398926,   1.3372838497161865,
       5.101180553436279,   3.166962146759033,   0.253373384475708,
       5.272202491760254,   0.8737403154373169,  9.0341796875,
       4.930244445800781,   5.145639896392822,   8.51688003540039,
       1.0039496421813965,  6.3629584312438965,  8.20095157623291,
       7.129164695739746,   6.775269031524658,   5.83862829208374,
       4.415182590484619,   9.107303619384766,   1.1548930406570435,
       4.394702434539795,   7.173308372497559,   1.648862361907959,
       3.040163516998291,   8.946229934692383,   8.740336418151855,
       7.152044773101807,   6.766063690185547,   8.682901382446289,
       2.8464317321777344,  8.757857322692871,   8.097877502441406,
       5.039367198944092,   1.713152527809143,   1.5446704626083374,
       7.220646858215332,   5.2453131675720215,  7.095609188079834,
       6.792170524597168,   5.975555896759033,   3.7161855697631836,
       2.0132927894592285,  3.0089516639709473,  1.4530837535858154,
       2.124783515930176,   1.3747084140777588,  0.4398918151855469,
       5.140370845794678,   0.16295194625854492, 5.689471244812012,
       9.149665832519531,   9.32123851776123,    1.5971916913986206,
       3.5363614559173584,  0.4872584342956543,  7.255306243896484,
       8.349767684936523,   0.977746844291687,   0.010267496109008789,
       9.964345932006836,   9.955519676208496,   2.3190832138061523,
       9.237786293029785,   4.200929641723633,   9.231035232543945,
       3.777331829071045,   4.507022857666016,   9.332846641540527,
       0.8198702335357666,  0.8076483011245728,  6.062283992767334,
       5.735506057739258,   6.782886505126953,   6.669310569763184,
       5.708680152893066,   7.5679931640625,     4.829475402832031,
       1.1562585830688477,  0.5352389812469482,  4.793148040771484,
       1.7251378297805786,  9.661691665649414,   7.695187568664551,
       2.569558620452881,   5.02672004699707,    4.213432312011719,
       0.4719752073287964,  3.2524518966674805,  4.827580451965332,
       1.7936384677886963,  1.8733304738998413,  9.386192321777344,
       2.442445755004883,   2.2374587059020996,  1.6268903017044067,
       1.9272565841674805,  0.04978537559509277, 5.165012359619141});
  std::optional<executorch::aten::Tensor> weight =
      std::optional<executorch::aten::Tensor>(tfDouble.make(
          {4},
          {5.4100823402404785,
           3.3440847396850586,
           0.9714162349700928,
           0.6811875104904175}));
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>(tfDouble.make(
          {4},
          {6.839208126068115,
           6.471728801727295,
           3.077871799468994,
           4.0067667961120605}));
  executorch::aten::Tensor running_mean = tfDouble.make(
      {4},
      {8.781468391418457,
       5.093882083892822,
       9.076446533203125,
       7.148240089416504});
  executorch::aten::Tensor running_var = tfDouble.make(
      {4},
      {1.0133814811706543,
       2.674386978149414,
       6.866252422332764,
       9.597100257873535});
  double momentum = 0.1;
  double eps = 0;
  executorch::aten::Tensor out0 = tfDouble.zeros({3, 4, 3, 3});
  executorch::aten::Tensor out1 = tfDouble.zeros({0});
  executorch::aten::Tensor out2 = tfDouble.zeros({0});
  executorch::aten::Tensor out0_expected = tfDouble.make(
      {3, 4, 3, 3},
      {-39.82401348817106,   -9.402336001242755,   -15.94316789328793,
       12.788231783114975,   -15.48431707375971,   -39.11630540562901,
       -12.621231365199568,  -12.447917505830254,  1.4282310938746887,
       3.8675726975224554,   9.187229957306027,    -1.2100164305929255,
       6.486653204105122,    2.53143305610451,     -3.4264695963512772,
       6.836370389027959,    -2.1579014884157313,  14.529114884449125,
       1.540793678645762,    1.6206449805000125,   2.870429566122108,
       0.08523948439124736,  2.07192874490631,     2.7533087138107413,
       2.3559763770206743,   2.2247803400083974,   1.8775493181436829,
       3.4058069855539905,   4.437536528655581,    2.688916473373198,
       3.401303695505533,    4.012278948950884,    2.797533181889745,
       3.1034601808347273,   4.402118755309629,    4.356845749255535,
       -1.9177122392967219,  -3.992068820138994,   6.30948495370292,
       -25.057127980919624,  6.712316477851606,    3.165423782960796,
       -13.271757354773879,  -31.14764712696726,   -32.05311088198505,
       10.820680738292133,   6.781385286813924,    10.564995283913401,
       9.94450345561426,     8.274634831663992,    3.654522124495513,
       0.1723322067015646,   2.2083225723239543,   -0.9732209832222724,
       0.5007544998438981,   0.2226870048863787,   -0.12386777243752074,
       1.6186916404984417,   -0.22653479260151907, 1.822253886542661,
       3.10501562425824,     3.1686209708035005,   0.3051659025884432,
       3.212566930139726,    2.542113280715248,    4.0303090947788895,
       4.270965334366817,    2.6499645871979483,   2.437229873043587,
       4.625987736153007,    4.624046970174033,    2.9449050525314044,
       9.29157194403522,     -17.777725500427557,  9.25529009664348,
       -20.05424357138422,   -16.132705822147077,  9.802449466893506,
       -35.948364280449816,  -36.01404792933805,   -7.774352749114295,
       7.783764743499358,    9.92551886693686,     9.69327114024373,
       7.728909325429857,    11.53095787270343,    5.931052201511217,
       -1.5801890954114401,  -2.850091828706626,   5.856767563411625,
       0.3525980358528258,   3.2948336043782858,   2.565812114767675,
       0.6656413209484225,   1.5765590689152436,   1.2750574158293795,
       -0.11197383213610211, 0.9188032005673643,   1.5027341303183273,
       2.829367353398184,    2.8468904728035618,   4.498860120209044,
       2.972030690911762,    2.9269570039352497,   2.792701843675069,
       2.858748044414565,    2.4459192831196264,   3.570683705559329});
  executorch::aten::Tensor out1_expected = tfDouble.make({0}, {});
  executorch::aten::Tensor out2_expected = tfDouble.make({0}, {});
  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      momentum,
      eps,
      out0,
      out1,
      out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTestNoWeight) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input = tfFloat.make(
      {4, 7, 5}, {4.1944355964660645,  3.537543296813965,  5.067144393920898,
                  9.735533714294434,   2.661299228668213,  0.43786585330963135,
                  8.926244735717773,   8.796754837036133,  2.2966713905334473,
                  7.153128623962402,   7.055768013000488,  0.3383845090866089,
                  0.8306580781936646,  2.355782985687256,  5.922069072723389,
                  1.9597464799880981,  2.731785774230957,  3.488309383392334,
                  9.926213264465332,   4.582781791687012,  1.2061834335327148,
                  9.317821502685547,   2.9511327743530273, 1.7717409133911133,
                  6.329389572143555,   0.844573974609375,  4.269064903259277,
                  3.9711995124816895,  0.7241052389144897, 2.239838123321533,
                  2.2850823402404785,  8.232909202575684,  5.126026153564453,
                  0.09984314441680908, 4.0997748374938965, 8.717041969299316,
                  2.4102187156677246,  8.769938468933105,  9.614383697509766,
                  4.630570411682129,   7.450488090515137,  2.7233500480651855,
                  5.878231525421143,   1.5304350852966309, 4.100255489349365,
                  3.448119640350342,   1.356201171875,     7.190479278564453,
                  4.431788444519043,   9.268322944641113,  7.564930438995361,
                  5.517428398132324,   6.40336799621582,   1.5203499794006348,
                  8.397398948669434,   9.415580749511719,  9.271242141723633,
                  6.522747993469238,   9.739391326904297,  3.8692879676818848,
                  4.59047794342041,    0.6365865468978882, 4.950358867645264,
                  2.111414670944214,   3.189572811126709,  2.893986701965332,
                  9.007704734802246,   1.0862338542938232, 4.761219024658203,
                  0.5109339952468872,  4.226720333099365,  9.338176727294922,
                  9.641677856445312,   8.222650527954102,  3.068296432495117,
                  3.6851234436035156,  2.7459187507629395, 9.115739822387695,
                  3.6909985542297363,  6.9336957931518555, 7.548684597015381,
                  9.266566276550293,   4.114157676696777,  1.0546678304672241,
                  1.881745457649231,   4.227387428283691,  1.3194853067398071,
                  6.739812850952148,   6.846013069152832,  7.290800094604492,
                  2.164156436920166,   3.4476895332336426, 7.013863563537598,
                  6.375678062438965,   2.4389731884002686, 5.257430553436279,
                  0.5499267578125,     5.771737098693848,  5.308223247528076,
                  0.2141815423965454,  5.413756370544434,  1.757289171218872,
                  9.780686378479004,   4.005618095397949,  7.078739166259766,
                  4.428859710693359,   2.348038673400879,  4.718813419342041,
                  1.896933913230896,   4.842776775360107,  6.077881813049316,
                  5.315243721008301,   5.951466083526611,  7.1189398765563965,
                  4.036149024963379,   9.996458053588867,  0.9982073307037354,
                  1.865202784538269,   0.5543112754821777, 4.5034308433532715,
                  4.392091751098633,   9.904728889465332,  2.8027725219726562,
                  8.39471435546875,    7.3801398277282715, 3.346047878265381,
                  1.2300896644592285,  6.925620079040527,  4.869058132171631,
                  0.06555616855621338, 2.475562572479248,  0.5495405197143555,
                  6.707937240600586,   0.946076512336731,  6.623589515686035,
                  5.87992000579834,    2.196932315826416,  8.085456848144531,
                  7.774395942687988,   8.86058235168457});
  std::optional<executorch::aten::Tensor> weight;
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {7},
          {3.2798612117767334,
           7.070205211639404,
           0.8457618951797485,
           8.21817684173584,
           4.158933162689209,
           9.13807201385498,
           5.7105536460876465}));
  executorch::aten::Tensor running_mean = tfFloat.make(
      {7},
      {8.596701622009277,
       8.133163452148438,
       1.8364977836608887,
       9.756494522094727,
       6.470483779907227,
       6.9614739418029785,
       5.237721920013428});
  executorch::aten::Tensor running_var = tfFloat.make(
      {7},
      {2.258641242980957,
       0.8535522222518921,
       9.372869491577148,
       8.911684036254883,
       9.814156532287598,
       0.5796539783477783,
       5.289167881011963});
  double momentum = 0.1;
  double eps = 0;
  executorch::aten::Tensor out0 = tfFloat.zeros({4, 7, 5});
  executorch::aten::Tensor out1 = tfFloat.zeros({0});
  executorch::aten::Tensor out2 = tfFloat.zeros({0});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {4, 7, 5}, {0.3506367802619934,  -0.08645286411046982, 0.9313285946846008,
                  4.037628650665283,   -0.669497013092041,   -1.259130597114563,
                  7.928630828857422,   7.788471698760986,    0.7528274059295654,
                  6.009422302246094,   2.5505621433258057,   0.3564245402812958,
                  0.5172187089920044,  1.0153789520263672,   2.180255651473999,
                  5.606414794921875,   5.865033149719238,    6.1184539794921875,
                  8.275029182434082,   6.485081672668457,    2.478527545928955,
                  5.067825794219971,   3.0355288982391357,   2.659057855606079,
                  4.113894939422607,   1.1037919521331787,   5.601710796356201,
                  5.210477828979492,   0.9455615878105164,   2.9364101886749268,
                  4.426696300506592,   7.012911796569824,    5.661986351013184,
                  3.47651743888855,    5.215754985809326,    3.3599345684051514,
                  -0.8365635275840759, 3.3951313495635986,   3.957016706466675,
                  0.6408365964889526,  6.331282138824463,    1.2146613597869873,
                  4.629482746124268,   -0.07654135674238205, 2.7050139904022217,
                  1.3721752166748047,  0.6888798475265503,   2.5945637226104736,
                  1.6934765577316284,  3.273261785507202,    7.484044551849365,
                  6.798170566558838,   7.094943046569824,    5.459225177764893,
                  7.762905597686768,   5.0990309715271,      5.052957057952881,
                  4.175616264343262,   5.202394008636475,    3.328611373901367,
                  6.0238728523254395,  0.8306096196174622,   6.496560573577881,
                  2.7677316665649414,  4.183845043182373,    4.691458225250244,
                  7.34980583190918,    3.90541672706604,     5.50336217880249,
                  3.655266761779785,   0.37211874127388,     3.7732315063476562,
                  3.9751780033111572,  3.0309712886810303,   -0.398685097694397,
                  2.255678176879883,   1.2390896081924438,   8.13373851776123,
                  2.2620372772216797,  5.771909713745117,    2.711566209793091,
                  3.2726879119873047,  1.5897270441055298,   0.590388298034668,
                  0.8605414032936096,  6.366031169891357,    5.391939163208008,
                  7.207645893096924,   7.243220806121826,    7.392216205596924,
                  2.784320116043091,   3.1940338611602783,   4.33238410949707,
                  4.128670692443848,   2.8720436096191406,   6.899885654449463,
                  0.716785728931427,   7.575404644012451,    6.966599464416504,
                  0.27579912543296814, 5.7870965003967285,   4.197203159332275,
                  7.685911178588867,   5.174814224243164,    6.511058807373047,
                  0.5066202878952026,  -0.8779374957084656,  0.699552595615387,
                  -1.178098201751709,  0.7820366024971008,   4.845582962036133,
                  4.020108699798584,   4.708751201629639,    5.972416877746582,
                  2.6356256008148193,  3.511096715927124,    0.5719462633132935,
                  0.8551380038261414,  0.4269539415836334,   1.7168775796890259,
                  6.421204090118408,   8.26783275604248,     5.88881254196167,
                  7.7620062828063965,  7.422143459320068,    3.1615889072418213,
                  2.486158609390259,   4.304216384887695,    3.6477456092834473,
                  2.1144304275512695,  3.2460241317749023,   0.7162784337997437,
                  8.805062294006348,   1.2371110916137695,   8.694275856018066,
                  5.989792346954346,   4.388367176055908,    6.94879674911499,
                  6.813542366027832,   7.285834312438965});
  executorch::aten::Tensor out1_expected = tfFloat.make({0}, {});
  executorch::aten::Tensor out2_expected = tfFloat.make({0}, {});
  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      momentum,
      eps,
      out0,
      out1,
      out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(
    OpNativeBatchNormLegitNoTrainingOutTest,
    SampleAtomicTestNoWeightNoBias) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input = tfFloat.make(
      {2, 4, 2, 2},
      {2.628833770751953,   7.391754150390625,  9.153281211853027,
       2.480319023132324,   6.5120697021484375, 5.680999755859375,
       9.440492630004883,   8.139138221740723,  5.618698596954346,
       0.21270036697387695, 8.981918334960938,  8.472748756408691,
       2.5718064308166504,  5.815331935882568,  0.08409619331359863,
       2.942138195037842,   1.8946051597595215, 9.46719741821289,
       0.5490684509277344,  2.2121663093566895, 5.5882368087768555,
       9.131031036376953,   5.822923183441162,  3.371715545654297,
       0.1542043685913086,  3.606675863265991,  2.65787410736084,
       5.136600494384766,   6.950716972351074,  6.051759719848633,
       7.304986953735352,   6.186429977416992});
  std::optional<executorch::aten::Tensor> weight;
  std::optional<executorch::aten::Tensor> bias;
  executorch::aten::Tensor running_mean = tfFloat.make(
      {4},
      {8.043643951416016,
       3.569627285003662,
       7.6375412940979,
       4.194377899169922});
  executorch::aten::Tensor running_var = tfFloat.make(
      {4},
      {7.512979507446289,
       0.0478285551071167,
       0.8684122562408447,
       1.9676220417022705});
  double momentum = 0.1;
  double eps = 0;
  executorch::aten::Tensor out0 = tfFloat.zeros({2, 4, 2, 2});
  executorch::aten::Tensor out1 = tfFloat.zeros({0});
  executorch::aten::Tensor out2 = tfFloat.zeros({0});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {2, 4, 2, 2},
      {-1.975500464439392,  -0.23783083260059357, 0.40483206510543823,
       -2.0296835899353027, 13.454400062561035,   9.65431022644043,
       26.844696044921875,  20.894216537475586,   -2.1664047241210938,
       -7.967539310455322,  1.4426401853561401,   0.896254301071167,
       -1.1567326784133911, 1.1555795669555664,   -2.9302234649658203,
       -0.8927228450775146, -2.2433712482452393,  0.5193589925765991,
       -2.734267234802246,  -2.127514362335205,   9.230148315429688,
       25.42967414855957,   10.303258895874023,   -0.9049565196037292,
       -8.03031063079834,   -4.325490474700928,   -5.343642234802246,
       -2.6837403774261475, 1.9649964570999146,   1.3241291046142578,
       2.2175559997558594,  1.4201356172561646});
  executorch::aten::Tensor out1_expected = tfFloat.make({0}, {});
  executorch::aten::Tensor out2_expected = tfFloat.make({0}, {});
  op_native_batch_norm_legit_no_training_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      momentum,
      eps,
      out0,
      out1,
      out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input = tfFloat.make(
      {4, 7}, {2.876736640930176,  7.67944860458374,   5.701690196990967,
               9.299789428710938,  3.023690700531006,  5.315116882324219,
               7.185585021972656,  6.911304473876953,  7.61051082611084,
               1.4963287115097046, 0.7381612062454224, 8.588483810424805,
               6.583977699279785,  8.831110000610352,  0.8165055513381958,
               7.087201118469238,  5.572513580322266,  4.446897983551025,
               4.444573402404785,  6.254056930541992,  5.906398296356201,
               9.971039772033691,  3.5423521995544434, 7.452159881591797,
               9.93700122833252,   1.8560808897018433, 1.524025797843933,
               7.3222975730896});
  std::optional<executorch::aten::Tensor> weight =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {7},
          {8.287437438964844,
           8.227645874023438,
           6.65926456451416,
           9.436124801635742,
           4.119281768798828,
           8.593960762023926,
           2.3760855197906494}));
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>(tfFloat.make(
          {7},
          {7.824275970458984,
           6.84327507019043,
           8.354326248168945,
           8.773970603942871,
           3.89609694480896,
           3.0753469467163086,
           3.1105971336364746}));
  executorch::aten::Tensor running_mean = tfFloat.make(
      {7},
      {9.700226783752441,
       0.1234668493270874,
       7.527220249176025,
       8.993252754211426,
       0.4736626148223877,
       7.7135701179504395,
       5.12320613861084});
  executorch::aten::Tensor running_var = tfFloat.make(
      {7},
      {3.585531234741211,
       6.615292549133301,
       0.24084866046905518,
       5.175800323486328,
       0.5886000394821167,
       6.23909854888916,
       1.5029621124267578});
  bool training = false;
  double momentum = 0.1;
  double eps = 0;
  executorch::aten::Tensor out0 = tfFloat.zeros({4, 7});
  executorch::aten::Tensor out1 = tfFloat.zeros({0});
  executorch::aten::Tensor out2 = tfFloat.zeros({0});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {4, 7}, {-22.039867401123047, 31.014127731323242,  -16.416650772094727,
               10.04538631439209,   17.5877628326416,    -5.17673921585083,
               7.1078033447265625,  -4.381907939910889,  30.793603897094727,
               -73.48003387451172,  -25.46548080444336,  47.46636962890625,
               -0.8111140131950378, 10.29708194732666,   -31.056814193725586,
               29.119586944580078,  -18.16947364807129,  -10.082839965820312,
               25.216796875,        -1.9462348222732544, 4.628543376922607,
               9.00953483581543,    17.779958724975586,  7.335818767547607,
               12.688335418701172,  11.318607330322266,  -18.22031593322754,
               7.372773170471191});
  executorch::aten::Tensor out1_expected = tfFloat.make({0}, {});
  executorch::aten::Tensor out2_expected = tfFloat.make({0}, {});
  op_native_batch_norm_legit_out(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      training,
      momentum,
      eps,
      out0,
      out1,
      out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
#define TEST_ENTRY(ctype, dtype) \
  test_2d_dtype<executorch::aten::ScalarType::dtype>();
  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
#undef TEST_ENTRY
}

TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input = tfFloat.make(
      {2, 3, 4}, {0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121,
                  144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529});
  std::optional<executorch::aten::Tensor> weight =
      std::optional<executorch::aten::Tensor>();
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>();
  bool training = true;
  double momentum = 1e-3;
  double eps = 1e-5;
  executorch::aten::Tensor out0 = tfFloat.zeros({2, 3, 4});
  executorch::aten::Tensor out1 = tfFloat.zeros({3});
  executorch::aten::Tensor out2 = tfFloat.zeros({3});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {2, 3, 4},
      {-1.01045656, -0.99964952, -0.96722847, -0.91319335, -1.08850884,
       -1.02468753, -0.94668359, -0.85449719, -1.12558389, -1.03595889,
       -0.93578988, -0.82507670, 0.54575467,  0.81593025,  1.10771990,
       1.42112350,  0.61339414,  0.84740579,  1.09560001,  1.35797679,
       0.64582670,  0.86198103,  1.08867943,  1.32592189});
  executorch::aten::Tensor out1_expected =
      tfFloat.make({3}, {93.5, 169.5, 277.5});
  executorch::aten::Tensor out2_expected =
      tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206});
  op_native_batch_norm_legit_no_stats_out(
      input, weight, bias, training, momentum, eps, out0, out1, out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest4D) {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tfFloat;

  executorch::aten::Tensor input =
      tfFloat.make({2, 3, 2, 2}, {0,   1,   4,   9,   16,  25,  36,  49,
                                  64,  81,  100, 121, 144, 169, 196, 225,
                                  256, 289, 324, 361, 400, 441, 484, 529});
  std::optional<executorch::aten::Tensor> weight =
      std::optional<executorch::aten::Tensor>(
          tfFloat.make({3}, {1.1, 0.7, 0.3}));
  std::optional<executorch::aten::Tensor> bias =
      std::optional<executorch::aten::Tensor>(
          tfFloat.make({3}, {1.7, 2.2, 3.3}));
  bool training = true;
  double momentum = 1e-3;
  double eps = 1e-5;
  executorch::aten::Tensor out0 = tfFloat.zeros({2, 3, 2, 2});
  executorch::aten::Tensor out1 = tfFloat.zeros({3});
  executorch::aten::Tensor out2 = tfFloat.zeros({3});
  executorch::aten::Tensor out0_expected = tfFloat.make(
      {2, 3, 2, 2},
      {0.58849782, 0.60038555, 0.63604873, 0.69548732, 1.43804383, 1.48271883,
       1.53732157, 1.60185206, 2.96232486, 2.98921227, 3.01926303, 3.05247688,
       2.30033016, 2.59752321, 2.91849184, 3.26323581, 2.62937593, 2.79318404,
       2.96691990, 3.15058374, 3.49374819, 3.55859423, 3.62660384, 3.69777656});
  executorch::aten::Tensor out1_expected =
      tfFloat.make({3}, {93.5, 169.5, 277.5});
  executorch::aten::Tensor out2_expected =
      tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206});
  op_native_batch_norm_legit_no_stats_out(
      input, weight, bias, training, momentum, eps, out0, out1, out2);
  EXPECT_TENSOR_CLOSE(out0, out0_expected);
  EXPECT_TENSOR_CLOSE(out1, out1_expected);
  EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
